from __future__ import annotations

import datetime
import functools
import json
import typing as t

from ansible.module_utils import _internal
from ansible.module_utils._internal import _messages
from ansible.module_utils._internal._datatag import (
    AnsibleSerializable,
    AnsibleSerializableWrapper,
    AnsibleTaggedObject,
    Tripwire,
    _AnsibleTaggedBytes,
    _AnsibleTaggedDate,
    _AnsibleTaggedDateTime,
    _AnsibleTaggedDict,
    _AnsibleTaggedFloat,
    _AnsibleTaggedInt,
    _AnsibleTaggedList,
    _AnsibleTaggedSet,
    _AnsibleTaggedStr,
    _AnsibleTaggedTime,
    _AnsibleTaggedTuple,
    AnsibleTagHelper,
    _tags,
)

# transformations to "final" JSON representations can only use:
# str, float, int, bool, None, dict, list
# NOT SUPPORTED: tuple, set -- the representation of these in JSON varies by profile (can raise an error, may be converted to list, etc.)
# This means that any special handling required on JSON types that are not wrapped/tagged must be done in a pre-pass before serialization.
# The final type map cannot contain any JSON types other than tuple or set.


_NoneType: t.Final[type] = type(None)

_json_subclassable_scalar_types: t.Final[tuple[type, ...]] = (str, float, int)
"""Scalar types understood by JSONEncoder which can also be subclassed."""

_json_scalar_types: t.Final[tuple[type, ...]] = (str, float, int, bool, _NoneType)
"""Scalar types understood by JSONEncoder."""

_json_container_types: t.Final[tuple[type, ...]] = (dict, list, tuple)
"""Container types understood by JSONEncoder."""

_json_types: t.Final[tuple[type, ...]] = _json_scalar_types + _json_container_types
"""Types understood by JSONEncoder."""

_intercept_containers = frozenset([
    # plain JSON containers
    dict,
    list,
    tuple,
    # their tagged counterparts
    _AnsibleTaggedDict,
    _AnsibleTaggedList,
    _AnsibleTaggedTuple,
])
"""Container types (plain and tagged) that are intercepted so that the scalars they hold can be intercepted as well."""

_common_module_types: frozenset[type[AnsibleSerializable]] = frozenset([
    _AnsibleTaggedBytes,
    _AnsibleTaggedDate,
    _AnsibleTaggedDateTime,
    _AnsibleTaggedDict,
    _AnsibleTaggedFloat,
    _AnsibleTaggedInt,
    _AnsibleTaggedList,
    _AnsibleTaggedSet,
    _AnsibleTaggedStr,
    _AnsibleTaggedTime,
    _AnsibleTaggedTuple,
])
"""
Types every Ansible module serialization profile is required to support.

Module-to-controller profiles should serialize all of these with full fidelity,
so infrastructure and library code can rely on these features even when a module itself does not use them.

Controller-to-module behavior for these types depends on the profile.
"""

_common_module_response_types: frozenset[type[AnsibleSerializable]] = frozenset([
    _messages.PluginInfo,
    _messages.PluginType,
    _messages.Event,
    _messages.EventChain,
    _messages.ErrorSummary,
    _messages.WarningSummary,
    _messages.DeprecationSummary,
    _tags.Deprecated,
])
"""Types every Ansible module-to-controller serialization profile is required to support."""

_T_encoder = t.TypeVar('_T_encoder', bound="AnsibleProfileJSONEncoder")
_T_decoder = t.TypeVar('_T_decoder', bound="AnsibleProfileJSONDecoder")


class _JSONSerializationProfile(t.Generic[_T_encoder, _T_decoder]):
    """
    Base class for JSON serialization profiles.

    A profile exposes its behavior only through class attributes and classmethods; the encoder and decoder hold a reference to the
    profile class itself (`type[_JSONSerializationProfile]`). Creating a subclass runs `__init_subclass__`, which finalizes the
    profile by building the derived lookup tables from `serialize_map` and `allowed_ansible_serializable_types`.
    """

    serialize_map: t.ClassVar[dict[type, t.Callable]]
    """
    Each concrete non-JSON type must be included in this mapping to support serialization.
    Including a JSON type in the mapping allows for overriding or disabling of serialization of that type.
    """

    deserialize_map: t.ClassVar[dict[str, t.Callable]]
    """A mapping of type keys to type dispatchers for deserialization."""

    allowed_ansible_serializable_types: t.ClassVar[frozenset[type[AnsibleSerializable]]] = frozenset()
    """Each concrete AnsibleSerializable derived type must be included in this set to support serialization."""

    _common_discard_tags: t.ClassVar[dict[type, t.Callable]]
    """
    Serialize map for tagged types to have their tags discarded.
    This is generated by __init_subclass__ and should not be manually updated.
    """

    _allowed_type_keys: t.ClassVar[frozenset[str]]
    """
    The set of type keys allowed during deserialization.
    This is generated by __init_subclass__ and should not be manually updated.
    """

    _unwrapped_json_types: t.ClassVar[frozenset[type]]
    """
    The set of types that do not need to be wrapped during serialization.
    This is generated by __init_subclass__ and should not be manually updated.
    """

    profile_name: t.ClassVar[str]
    """
    The user-facing name of the profile, derived from the module name in which the profile resides.
    Used to load the profile dynamically at runtime.
    This is generated by __init_subclass__ and should not be manually updated.
    """

    encode_strings_as_utf8: t.ClassVar[bool] = False
    r"""
    When enabled, JSON encoding will result in UTF8 strings being emitted.
    Otherwise, non-ASCII strings will be escaped with `\uXXXX` escape sequences.
    """

    @classmethod
    def pre_serialize(cls, encoder: _T_encoder, o: t.Any) -> t.Any:
        """Hook applied by the encoder to the top-level value before encoding. The default implementation returns `o` unchanged."""
        return o

    @classmethod
    def post_deserialize(cls, decoder: _T_decoder, o: t.Any) -> t.Any:
        """Hook applied by the decoder to the fully decoded top-level value. The default implementation returns `o` unchanged."""
        return o

    @classmethod
    def cannot_serialize_error(cls, target: t.Any, /) -> t.NoReturn:
        """Raise a TypeError stating that `target`'s type cannot be serialized by this profile."""
        raise TypeError(f'Object of type {type(target).__name__!r} is not JSON serializable by the {cls.profile_name!r} profile.')

    @classmethod
    def cannot_deserialize_error(cls, target_type_name: str, /) -> t.NoReturn:
        """Raise a TypeError stating that the named type cannot be deserialized by this profile."""
        raise TypeError(f'Object of type {target_type_name!r} is not JSON deserializable by the {cls.profile_name!r} profile.')

    @classmethod
    def unsupported_target_type_error(cls, target_type_name: str, _value: dict) -> t.NoReturn:
        """Raise the standard deserialization error for `target_type_name`; `_value` is accepted for overrides but unused here."""
        cls.cannot_deserialize_error(target_type_name)

    @classmethod
    def discard_tags(cls, value: AnsibleTaggedObject) -> object:
        """Return `value._native_copy()`; this is the serializer `_common_discard_tags` assigns to tagged types."""
        return value._native_copy()

    @classmethod
    def deserialize_serializable(cls, value: dict[str, t.Any]) -> object:
        """
        Deserialize a dict carrying `AnsibleSerializable._TYPE_KEY`.

        Raises TypeError (via `cannot_deserialize_error`) when the type key is not in `_allowed_type_keys` for this profile.
        """
        type_key = value[AnsibleSerializable._TYPE_KEY]

        if type_key not in cls._allowed_type_keys:
            cls.cannot_deserialize_error(type_key)

        return AnsibleSerializable._deserialize(value)

    @classmethod
    def serialize_as_list(cls, value: t.Iterable) -> list:
        """Serialize an iterable as a list built via `AnsibleTagHelper.tag_copy` from `value`."""
        # DTFIX-FUTURE: once we have separate control/data channels for module-to-controller (and back), warn about this conversion
        return AnsibleTagHelper.tag_copy(value, (item for item in value), value_type=list)

    @classmethod
    def serialize_as_isoformat(cls, value: datetime.date | datetime.time | datetime.datetime) -> str:
        """Serialize a date, time or datetime as its ISO 8601 string (`isoformat()`)."""
        return value.isoformat()

    @classmethod
    def serialize_serializable_object(cls, value: AnsibleSerializable) -> t.Any:
        """Serialize an AnsibleSerializable instance using its own `_serialize()` implementation."""
        return value._serialize()

    @classmethod
    def post_init(cls) -> None:
        """
        Subclass hook run by `__init_subclass__`.

        It runs after `deserialize_map` and `_common_discard_tags` have been reset, and before `serialize_map` is read.
        A subclass that does not define `serialize_map` as a class attribute must create it here.
        """

    @classmethod
    def maybe_wrap(cls, o: t.Any) -> t.Any:
        """
        Return `o` unchanged if its exact type is in `_unwrapped_json_types`, otherwise wrap it in a `_WrappedValue`.

        JSONEncoder does not natively understand `_WrappedValue`, so it passes wrapped values to the encoder's `default` hook.
        That hook unwraps them, which lets the profile intercept values even inside JSON-native containers.
        """
        if type(o) in cls._unwrapped_json_types:
            return o

        return _WrappedValue(o)

    @classmethod
    def handle_key(cls, k: t.Any) -> t.Any:
        """Validation/conversion hook before a dict key is serialized. The default implementation only accepts str-typed keys."""
        # NOTE: Since JSON requires string keys, there is no support for preserving tags on dictionary keys during serialization.

        if not isinstance(k, str):  # DTFIX-FUTURE: optimize this to use all known str-derived types in type map / allowed types
            raise TypeError(f'Key of type {type(k).__name__!r} is not JSON serializable by the {cls.profile_name!r} profile.')

        return k

    @classmethod
    def _handle_key_str_fallback(cls, k: t.Any) -> t.Any:
        """Legacy implementations should use this key handler for backward compatibility with stdlib JSON key conversion quirks."""
        # DTFIX-FUTURE: optimized exact-type table lookup first

        if isinstance(k, str):
            return k

        # bool is an int subclass, so True/False are rendered as 'true'/'false', and None as 'null', matching the stdlib's key conversion.
        if k is None or isinstance(k, (int, float)):
            return json.dumps(k)

        raise TypeError(f'Key of type {type(k).__name__!r} is not JSON serializable by the {cls.profile_name!r} profile.')

    @classmethod
    def default(cls, o: t.Any) -> t.Any:
        """
        Fallback for values the encoder could not map by exact type.

        str/float/int subclasses are returned as-is, and mappings and other iterables are converted to dict and list with their keys
        validated and their values wrapped. Anything else goes to `last_chance`.
        """
        # Preserve the built-in JSON encoder support for subclasses of scalar types.

        if isinstance(o, _json_subclassable_scalar_types):
            return o

        # Preserve the built-in JSON encoder support for subclasses of dict and list.
        # Additionally, add universal support for mappings and sequences/sets by converting them to dict and list, respectively.

        if _internal.is_intermediate_mapping(o):
            return {cls.handle_key(k): cls.maybe_wrap(v) for k, v in o.items()}

        if _internal.is_intermediate_iterable(o):
            return [cls.maybe_wrap(v) for v in o]

        return cls.last_chance(o)

    @classmethod
    def last_chance(cls, o: t.Any) -> t.Any:
        """
        Final handler for an unsupported value.

        A Tripwire is tripped first (presumably raising its own, more specific error). Otherwise, or if tripping returns, the profile's
        standard serialization TypeError is raised.
        """
        if isinstance(o, Tripwire):
            o.trip()

        cls.cannot_serialize_error(o)

    def __init_subclass__(cls, **kwargs) -> None:
        """
        Finalize a profile when it is subclassed.

        `deserialize_map`, `_common_discard_tags`, `profile_name`, `_allowed_type_keys` and `_unwrapped_json_types` are rebuilt on the subclass.
        `allowed_ansible_serializable_types` is rebound on the subclass, and `serialize_map` is extended in place.
        """
        # Chain to the parent implementation, as cooperative `__init_subclass__` overrides should.
        # This lets `typing.Generic` record the subclass's type parameters, and causes unexpected class keyword arguments to be
        # rejected instead of silently ignored.
        super().__init_subclass__(**kwargs)

        cls.deserialize_map = {}
        cls._common_discard_tags = {obj: cls.discard_tags for obj in _common_module_types if issubclass(obj, AnsibleTaggedObject)}

        # `post_init` sees the freshly reset tables above and must leave `serialize_map` available for the steps below.
        cls.post_init()

        # the last dotted component of the module name with leading underscores removed, e.g. a module named `_legacy` yields `legacy`
        cls.profile_name = cls.__module__.rsplit('.', maxsplit=1)[-1].lstrip('_')

        # serializers that are themselves AnsibleSerializableWrapper subclasses are implicitly allowed serializable types
        wrapper_types = set(obj for obj in cls.serialize_map.values() if isinstance(obj, type) and issubclass(obj, AnsibleSerializableWrapper))

        cls.allowed_ansible_serializable_types |= wrapper_types

        # no current need to preserve tags on controller-only types or custom behavior for anything in `allowed_serializable_types`
        cls.serialize_map.update({obj: cls.serialize_serializable_object for obj in cls.allowed_ansible_serializable_types})
        # controller-provided serializers only fill in types the profile has not already mapped
        cls.serialize_map.update({obj: func for obj, func in _internal.get_controller_serialize_map().items() if obj not in cls.serialize_map})

        cls.deserialize_map[AnsibleSerializable._TYPE_KEY] = cls.deserialize_serializable  # always recognize tagged types

        cls._allowed_type_keys = frozenset(obj._type_key for obj in cls.allowed_ansible_serializable_types)

        # `maybe_wrap` passes values of exactly these types through unwrapped
        cls._unwrapped_json_types = frozenset(
            {obj for obj in cls.serialize_map if not issubclass(obj, _json_types)}  # custom types that do not extend JSON-native types
            | {obj for obj in _json_scalar_types if obj not in cls.serialize_map}  # JSON-native scalars lacking custom handling
        )


class _WrappedValue:
    __slots__ = ('wrapped',)

    def __init__(self, wrapped: t.Any) -> None:
        self.wrapped = wrapped


class AnsibleProfileJSONEncoder(json.JSONEncoder):
    """
    Profile based JSON encoder capable of handling Ansible internal types.

    Subclasses must set `_profile` to a `_JSONSerializationProfile` subclass; `profile_name` is copied from it when the subclass is created.
    """

    _wrap_container_types = (list, set, tuple, dict)
    _profile: type[_JSONSerializationProfile]

    profile_name: str

    def __init__(self, **kwargs):
        """Accept the standard `json.JSONEncoder` keyword arguments; `ensure_ascii` is forced off when the profile requests UTF8 output."""
        self._wrap_types = self._wrap_container_types + (AnsibleSerializable,)

        if self._profile.encode_strings_as_utf8:
            kwargs.update(ensure_ascii=False)  # emit non-ASCII characters as-is instead of \uXXXX escapes

        super().__init__(**kwargs)

    def __init_subclass__(cls, **kwargs) -> None:
        # Chain to the parent implementation (cooperative inheritance); this also rejects unexpected class keyword arguments
        # instead of silently ignoring them.
        super().__init_subclass__(**kwargs)

        cls.profile_name = cls._profile.profile_name

    def encode(self, o):
        """Run the profile's `pre_serialize` hook on `o`, wrap the result if required, then encode it with the standard encoder."""
        o = self._profile.maybe_wrap(self._profile.pre_serialize(self, o))

        return super().encode(o)

    def default(self, o: t.Any) -> t.Any:
        """
        Convert a value the base encoder does not handle natively into something it can encode.

        `_WrappedValue` instances are unwrapped first. An exact-type match in the profile's `serialize_map` takes priority. Otherwise,
        plain dict/list/tuple values are rebuilt with their keys validated and their values wrapped. Everything else is delegated to the
        profile's `default`.
        """
        o_type = type(o)

        if o_type is _WrappedValue:  # pylint: disable=unidiomatic-typecheck
            o = o.wrapped
            o_type = type(o)

        if mapped_callable := self._profile.serialize_map.get(o_type):
            return self._profile.maybe_wrap(mapped_callable(o))

        # This is our last chance to intercept the values in containers, so they must be wrapped here.
        # Only containers natively understood by the built-in JSONEncoder are recognized, since any other container types must be present in serialize_map.

        if o_type is dict:  # pylint: disable=unidiomatic-typecheck
            return {self._profile.handle_key(k): self._profile.maybe_wrap(v) for k, v in o.items()}

        if o_type is list or o_type is tuple:  # pylint: disable=unidiomatic-typecheck
            return [self._profile.maybe_wrap(v) for v in o]  # JSONEncoder converts tuple to a list, so just make it a list now

        # Any value here is a type not explicitly handled by this encoder.
        # The profile default handler is responsible for generating an error or converting the value to a supported type.

        return self._profile.default(o)


class AnsibleProfileJSONDecoder(json.JSONDecoder):
    """
    Profile based JSON decoder capable of handling Ansible internal types.

    Subclasses must set `_profile` to a `_JSONSerializationProfile` subclass; `profile_name` is copied from it when the subclass is created.
    """

    _profile: type[_JSONSerializationProfile]

    profile_name: str

    def __init__(self, **kwargs):
        """Accept the standard `json.JSONDecoder` keyword arguments; any caller-supplied `object_hook` is replaced by this decoder's own."""
        kwargs.update(object_hook=self.object_hook)

        super().__init__(**kwargs)

    def __init_subclass__(cls, **kwargs) -> None:
        # Chain to the parent implementation (cooperative inheritance); this also rejects unexpected class keyword arguments
        # instead of silently ignoring them.
        super().__init_subclass__(**kwargs)

        cls.profile_name = cls._profile.profile_name

    def raw_decode(self, s: str, idx: int = 0) -> tuple[t.Any, int]:
        """
        Decode one JSON document from `s` starting at `idx`.

        Returns the value after the profile's `post_deserialize` hook, together with the index where the document ended.
        When strict UTF8 checking is enabled, an unencodable string raises the error built by `_create_encoding_check_error`.
        """
        obj, end = super().raw_decode(s, idx)

        if _string_encoding_check_enabled():
            try:
                # Dicts were already checked by `object_hook`; this covers top-level strings and lists.
                _recursively_check_string_encoding(obj)
            except UnicodeEncodeError as ex:
                raise _create_encoding_check_error() from ex

        obj = self._profile.post_deserialize(self, obj)

        return obj, end

    def object_hook(self, pairs: dict[str, object]) -> object:
        """
        Post-process each decoded JSON object (`pairs` is the decoded dict).

        The first key of the profile's `deserialize_map` (in insertion order) that is present in `pairs` selects the callable that
        converts the dict. When no key matches, the dict is returned unchanged.
        """
        if _string_encoding_check_enabled():
            try:
                for key, value in pairs.items():
                    key.encode()
                    # Nested objects have already passed through this hook, so only strings and lists need checking here.
                    _recursively_check_string_encoding(value)
            except UnicodeEncodeError as ex:
                raise _create_encoding_check_error() from ex

        for mapped_key, mapped_callable in self._profile.deserialize_map.items():
            if mapped_key in pairs:
                return mapped_callable(pairs)

        return pairs


_check_encoding_setting = 'MODULE_STRICT_UTF8_RESPONSE'
r"""
The setting to control whether strings are checked to verify they can be encoded as valid UTF8.
This is currently only used during deserialization, to prevent string values from entering the controller which will later fail to be encoded as bytes.

The encoding failure can occur when the string represents one of two kinds of values:
1) It was created through decoding bytes with the `surrogateescape` error handler, and that handler is not being used when encoding.
2) It represents an invalid UTF8 value, such as `"\ud8f3"` in a JSON payload. This cannot be encoded, even using the `surrogateescape` error handler.

Although this becomes an error during deserialization, there are other opportunities for these values to become strings within Ansible.
Future code changes should further restrict bytes to string conversions to eliminate use of `surrogateescape` where appropriate.
Additional warnings at other boundaries may be needed to give users an opportunity to resolve the issues before they become errors.
"""
# DTFIX-FUTURE: add strict UTF8 string encoding checking to serialization profiles (to match the checks performed during deserialization)
# DTFIX3: the surrogateescape note above isn't quite right, for encoding use surrogatepass, which does work
# DTFIX-FUTURE: this config setting should probably be deprecated


def _create_encoding_check_error() -> Exception:
    """
    Build (without raising) the AnsibleRuntimeError reported when a UTF8 string encoding check fails.

    The checks are only performed in the controller context. Because this is module_utils code, the `ansible.errors` module must be
    loaded dynamically.
    """
    errors_module = _internal.import_controller_module('ansible.errors')  # bypass AnsiballZ import scanning
    hint = f'This check can be disabled with the `{_check_encoding_setting}` setting.'

    return errors_module.AnsibleRuntimeError(
        message='Refusing to deserialize an invalid UTF8 string value.',
        help_text=hint,
    )


@functools.lru_cache
def _string_encoding_check_enabled() -> bool:
    """
    Report whether JSON deserialization should verify that strings can be encoded as valid UTF8.

    The value is read from the controller config and cached after the first call. If `import_controller_module` yields nothing
    (presumably outside the controller), the check is disabled.
    """
    constants = _internal.import_controller_module('ansible.constants')  # bypass AnsiballZ import scanning

    if not constants:
        return False

    return constants.config.get_config_value(_check_encoding_setting)  # covers all profile-based deserializers, not just modules


def _recursively_check_string_encoding(value: t.Any) -> None:
    """Recursively check the given object to ensure all strings can be encoded as valid UTF8."""
    value_type = type(value)

    if value_type is str:
        value.encode()
    elif value_type is list:  # dict is handled by the JSON deserializer
        for item in value:
            _recursively_check_string_encoding(item)
